# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import logging
import types
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats.config import (
    _validate_elem_dtype,
    _validate_kernel_preference,
)
from torchao.prototype.mx_formats.mx_tensor import (
    MXTensor,
    QuantizeTensorToMXKwargs,
    ScaleCalculationMode,
)
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    QuantizeTensorToNVFP4Kwargs,
    per_tensor_amax_to_scale,
)
from torchao.quantization.quant_api import _quantization_type
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.utils import (
    is_sm_at_least_100,
    torch_version_at_least,
)


@dataclass
class MXDynamicActivationMXWeightConfig(AOBaseConfig):
    """
    MX Format Inference Quantization Configuration

    Describes how a linear layer's weight, and the activations fed to it, are
    quantized to an MX element dtype for inference.

    Configuration parameters:
    - block_size: int, block size forwarded to the MX quantization routines (default: 32)
    - activation_dtype: torch.dtype, element dtype for activations (default: torch.float8_e4m3fn)
    - weight_dtype: torch.dtype, element dtype for weights; must currently equal
      activation_dtype (default: torch.float8_e4m3fn)
    - kernel_preference: KernelPreference, which kernel to run for mm (default: AUTO)

    Requirements:
    - NVIDIA SM100+ hardware (Blackwell or newer) is required for execution
    - PyTorch 2.5+ for proper serialization support
    """

    block_size: int = 32

    # Element dtypes for activations and weights (Fp8 and Fp4 formats).
    activation_dtype: torch.dtype = torch.float8_e4m3fn
    weight_dtype: torch.dtype = torch.float8_e4m3fn

    # Kernel selection for the matmul.
    kernel_preference: KernelPreference = KernelPreference.AUTO

    def __post_init__(self):
        # Mixed activation/weight dtypes are rejected up front.
        dtypes_match = self.activation_dtype == self.weight_dtype
        assert dtypes_match, "For now - we only support matching input/weight dtypes."

        # Each element dtype is validated individually, activation first.
        for elem_dtype in (self.activation_dtype, self.weight_dtype):
            _validate_elem_dtype(elem_dtype)

        _validate_kernel_preference(
            self.kernel_preference, self.block_size, self.weight_dtype
        )


def _linear_extra_repr(self):
    """Build the ``extra_repr`` text for a linear module with a quantized weight.

    Intended to be bound to a module instance as a method. Reports the weight's
    second dimension as ``in_features``, its first dimension as
    ``out_features``, and the weight's quantization description.
    """
    w = self.weight
    out_features, in_features = w.shape[0], w.shape[1]
    quant_desc = _quantization_type(w)
    return f"in_features={in_features}, out_features={out_features}, weight={quant_desc}"


@register_quantize_module_handler(MXDynamicActivationMXWeightConfig)
def _mx_inference_linear_transform(
    module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig
):
    """Quantization handler for MXDynamicActivationMXWeightConfig.

    Replaces ``module.weight`` with a non-trainable parameter wrapping an
    ``MXTensor`` built from the original weight, and installs a custom
    ``extra_repr`` on the module. The activation quantization settings are
    attached to the weight tensor via ``act_quant_kwargs``.

    Args:
        module: Module whose ``weight`` is quantized in place; the weight must
            be bfloat16.
        config: Supplies element dtypes, block size and kernel preference.

    Returns:
        The same ``module`` instance, mutated.

    Raises:
        AssertionError: If the weight dtype is not ``torch.bfloat16``.
    """
    weight = module.weight
    assert weight.dtype == torch.bfloat16, (
        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
    )

    # Settings common to the activation recipe and the weight conversion;
    # both use swizzled scales and RCEIL scale calculation.
    shared_kwargs = dict(
        block_size=config.block_size,
        kernel_preference=config.kernel_preference,
        is_swizzled_scales=True,
        scaling_mode=ScaleCalculationMode.RCEIL,
    )

    activation_recipe = QuantizeTensorToMXKwargs(
        elem_dtype=config.activation_dtype,
        **shared_kwargs,
    )

    # Convert the high-precision weight to an MX tensor.
    mx_weight = MXTensor.to_mx(
        weight,
        config.weight_dtype,
        act_quant_kwargs=activation_recipe,
        **shared_kwargs,
    )

    module.weight = torch.nn.Parameter(mx_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig):
    """
    NVIDIA FP4 (NVFP4) Inference Quantization Configuration

    A specialized configuration for NVIDIA's FP4 format, covering both weights
    and dynamically quantized activations.

    Configuration parameters:
    - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True)
    - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)

    Format:
    - Data: float4_e2m1fn_x2
    - Scales: float8_e4m3fn
    - Block size: 16 along the reduction dim

    Note: Triton kernel only works with DYNAMIC mode and has constraints that input dimensions
    must satisfy M % 128 == 0 and K % 64 == 0. Will automatically fallback when constraints aren't met.

    Raises:
        RuntimeError: On construction, if the installed PyTorch is older than 2.8.
    """

    use_triton_kernel: bool = True
    use_dynamic_per_tensor_scale: bool = True

    def __post_init__(self):
        # Nothing further to validate on a supported PyTorch version.
        if torch_version_at_least("2.8.0"):
            return
        raise RuntimeError(
            "NVFP4DynamicActivationNVFP4WeightConfig requires PyTorch 2.8 or later"
        )


@register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig)
def _nvfp4_inference_linear_transform(
    module: torch.nn.Linear, config: NVFP4DynamicActivationNVFP4WeightConfig
):
    """Quantization handler for NVFP4DynamicActivationNVFP4WeightConfig.

    Replaces ``module.weight`` with a non-trainable parameter wrapping an
    ``NVFP4Tensor`` and installs a custom ``extra_repr`` on the module.

    Args:
        module: Linear layer to quantize in place.
        config: Supplies the triton and per-tensor-scale preferences.

    Returns:
        The same ``module`` instance, mutated.

    Raises:
        AssertionError: If ``is_sm_at_least_100()`` is false.
        RuntimeError: If either of the weight's last two dims is not a
            multiple of 16, or if the layer has a bias while its weight is
            float32.
    """
    assert is_sm_at_least_100(), (
        "NVFP4 DYNAMIC mode is only supported on sm100+ machines"
    )

    weight = module.weight

    last_two_dims_ok = weight.shape[-2] % 16 == 0 and weight.shape[-1] % 16 == 0
    if not last_two_dims_ok:
        raise RuntimeError(
            f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
        )

    has_bias = module.bias is not None
    if has_bias and weight.dtype == torch.float32:
        raise RuntimeError(
            "Bias is not supported when module weight is in fp32 (out_dtype=Float32). "
            "Please use bfloat16 or float16 weights, or remove the bias from the linear layer."
        )

    # The per-tensor scale is derived from the weight's absolute maximum, and
    # only when the config asks for it.
    weight_scale = (
        per_tensor_amax_to_scale(torch.max(torch.abs(weight)))
        if config.use_dynamic_per_tensor_scale
        else None
    )

    activation_recipe = QuantizeTensorToNVFP4Kwargs(
        use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale,
        use_triton_kernel=config.use_triton_kernel,
        is_swizzled_scales=True,
    )

    # Weights are always built through the non-triton path; the configured
    # triton preference is written onto the tensor afterwards.
    nvfp4_weight = NVFP4Tensor.to_nvfp4(
        weight,
        per_tensor_scale=weight_scale,
        is_swizzled_scales=True,
        use_triton_kernel=False,
        act_quant_kwargs=activation_recipe,
    )
    nvfp4_weight.use_triton_kernel = config.use_triton_kernel

    module.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class NVFP4WeightOnlyConfig(AOBaseConfig):
    """
    NVIDIA FP4 (NVFP4) Weight-Only Inference Quantization Configuration

    Configuration parameters:
    - use_dynamic_per_tensor_scale: bool, whether a per tensor scale is derived
      from the weight's absolute maximum when the weight is quantized (default: True)

    Raises:
        RuntimeError: On construction, if the installed PyTorch is older than 2.8.
    """

    use_dynamic_per_tensor_scale: bool = True

    def __post_init__(self):
        # Validate PyTorch version
        if not torch_version_at_least("2.8.0"):
            # Name this class in the message so the failing config is
            # identifiable from the error alone.
            raise RuntimeError("NVFP4WeightOnlyConfig requires PyTorch 2.8 or later")


@register_quantize_module_handler(NVFP4WeightOnlyConfig)
def _nvfp4_weight_only_linear_transform(
    module: torch.nn.Linear, config: NVFP4WeightOnlyConfig
):
    """Quantization handler for NVFP4WeightOnlyConfig.

    Replaces ``module.weight`` with a non-trainable parameter wrapping an
    ``NVFP4Tensor`` that carries no activation quantization settings, and
    installs a custom ``extra_repr`` on the module.

    Args:
        module: Linear layer to quantize in place.
        config: Controls whether a per-tensor scale is computed.

    Returns:
        The same ``module`` instance, mutated.

    Raises:
        RuntimeError: If either of the weight's last two dims is not a
            multiple of 16.
    """
    weight = module.weight

    last_two_dims_ok = weight.shape[-2] % 16 == 0 and weight.shape[-1] % 16 == 0
    if not last_two_dims_ok:
        raise RuntimeError(
            f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
        )

    # The per-tensor scale is derived from the weight's absolute maximum, and
    # only when the config asks for it.
    weight_scale = (
        per_tensor_amax_to_scale(torch.max(torch.abs(weight)))
        if config.use_dynamic_per_tensor_scale
        else None
    )

    # Weight-only: no activation quantization kwargs are attached.
    nvfp4_weight = NVFP4Tensor.to_nvfp4(
        weight,
        per_tensor_scale=weight_scale,
        is_swizzled_scales=True,
        act_quant_kwargs=None,
    )

    module.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


# Import-time side effect: mark the tensor subclasses and helper types used by
# this module as safe globals, so checkpoints that reference them can be loaded
# with `torch.load(..., weights_only=True)`.
torch.serialization.add_safe_globals(
    [
        MXTensor,
        NVFP4Tensor,
        QuantizeTensorToMXKwargs,
        QuantizeTensorToNVFP4Kwargs,
        ScaleCalculationMode,
    ]
)


import torch.nn as nn


def _auto_filter_for_nfp4(mod: nn.Module, fqn: str) -> bool:
    """Generic Filter fn for NVFP4 that is best practice for most models."""
    # Define any FQNs you want to exclude directly in the function
    filter_fqns = ["embedder", "embed", "embedding", "time_text_embed"]

    # Only support Linear modules
    if not isinstance(mod, nn.Linear):
        return False

    # If the fqn matches any filtered fqn, then we should not convert this module
    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
    if is_filtered_fqn:
        return False

    # All dims must be divisible by 16 due to float8 hardware requirements.
    N, K = mod.weight.shape
    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
    if not dims_multiples_of_16:
        return False
    if N <= 64:
        print("skiping small linear layer")
        # TODO cublas doesn't like this one
        return False

    # Dims below these thresholds may result in worse performance
    if K <= 1024 and N <= 1024:
        print("skiping small linear layer")
        return False
    return True
